While machine learning on large, centrally aggregated datasets is the dominant paradigm in the field, central aggregation has a number of drawbacks, chief among them privacy. Federated learning aims to address this and has shown promise for text completion tasks on mobile devices. The TensorFlow Federated API provides methods to train federated models and conduct federated learning experiments on data that is grouped by client but never aggregated. Through our research partnership with Google, we aim to build on the existing body of federated learning experiments, with a particular focus on enhancing text models for natural language understanding tasks such as next-word prediction.
Federated learning trains machine learning models in a distributed fashion: rather than centralizing data, a central server passes model parameters to distributed clients, which run stochastic gradient descent locally and send their updated parameters back. McMahan et al. propose the Federated Averaging algorithm in “Communication-Efficient Learning of Deep Networks from Decentralized Data.” Our goal is to replicate the existing network architectures for Federated Averaging, stress-testing their limits within our simulated environment. We then apply pretraining and pretrained model layers to measure the impact of starting from learned model weights, compared to random initialization, on the task of next-word prediction on Stack Overflow. Specifically, in this interim delivery, we measure the effect of using pretrained embeddings on the number of training rounds required to achieve a fixed level of accuracy.
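The server-client exchange in Federated Averaging can be sketched in a few lines. The following is a minimal NumPy illustration (not the TensorFlow Federated implementation, and the linear model and hyperparameters are illustrative only): each sampled client takes local gradient steps starting from the server's weights, and the server averages the returned weights, weighted by client dataset size.

```python
import numpy as np

def client_update(weights, X, y, lr=0.1, epochs=1):
    """One client's local SGD (full-batch here) on squared error for a linear model."""
    w = weights.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def federated_averaging_round(global_w, client_data):
    """One round: each client trains locally; the server takes a size-weighted average."""
    updates, sizes = [], []
    for X, y in client_data:
        updates.append(client_update(global_w, X, y))
        sizes.append(len(y))
    return np.average(updates, axis=0, weights=np.array(sizes, dtype=float))

# Toy simulation: five clients drawn from the same linear model.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for _ in range(5):
    X = rng.normal(size=(50, 2))
    clients.append((X, X @ true_w + 0.01 * rng.normal(size=50)))

w = np.zeros(2)
for _ in range(100):
    w = federated_averaging_round(w, clients)
```

In the real algorithm the server samples a fraction of clients per round rather than using all of them; our experiments sample 10 client datasets per round.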
The main dataset used for these experiments is hosted by Kaggle and made available through the tff.simulation.datasets module in the TensorFlow Federated API. Stack Overflow owns the data and has released it under the CC BY-SA 3.0 license. The dataset contains the full body text of all Stack Overflow questions and answers along with metadata, and the API pointer is updated quarterly. At the time of writing, the data is split into the following sets:
The EDA notebook linked here contains an exploratory analysis of the data with example records and visualizations. From this notebook we identify the following challenges with the data:
For this interim delivery, we train two four-layer neural networks and compare train and validation accuracy at each of 500 training rounds, sampling 10 training client datasets per round (each with at most 5,000 non-IID text samples from Stack Overflow) and evaluating on a fixed set of 10,000 validation text samples. Each of the two models is trained with the Federated Averaging algorithm as in McMahan et al. The model architecture is as follows:
Note in the above that one network starts with a randomly initialized embedding layer, while the other starts with pretrained GloVe embeddings trained on the Wikipedia 2014 + Gigaword 5 text corpus. A majority, but not all, of the words in the Stack Overflow vocabulary have corresponding GloVe embeddings. For this reason, we make this layer trainable so that the model can learn embeddings for the roughly 10% of words without GloVe representations while fine-tuning the embeddings of words that have them.
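A mixed embedding matrix of this kind can be built by copying in pretrained vectors where they exist and randomly initializing the rest. The sketch below assumes a `glove` dict mapping words to vectors and a `vocab` list; both names, the toy inputs, and the initialization scale are illustrative, not our exact preprocessing code.

```python
import numpy as np

def build_embedding_matrix(vocab, glove, dim=100, seed=0):
    """Rows for words with GloVe vectors are copied in; the remaining rows
    (roughly 10% of our vocabulary) are randomly initialized and will be
    learned during training, since the embedding layer is set to trainable."""
    rng = np.random.default_rng(seed)
    matrix = np.zeros((len(vocab), dim))
    for i, word in enumerate(vocab):
        if word in glove:
            matrix[i] = glove[word]          # pretrained representation
        else:
            matrix[i] = rng.normal(scale=0.05, size=dim)  # no GloVe vector
    return matrix

# Toy inputs; in the experiments, vocab comes from Stack Overflow and
# glove from the 100-dimensional GloVe vectors.
vocab = ["the", "tensorflow", "stackoverflowism"]
glove = {"the": np.ones(100), "tensorflow": np.full(100, 0.5)}
emb = build_embedding_matrix(vocab, glove)
```

In Keras, such a matrix is typically passed to the model via `tf.keras.layers.Embedding(..., embeddings_initializer=tf.keras.initializers.Constant(emb), trainable=True)`.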
With the model design for the two networks fixed except for the initial embedding layer, we train both models using the Adam optimizer and sparse categorical cross-entropy loss, measuring accuracy (including out-of-vocabulary and end-of-sentence tokens) after each training round on both the sampled training client datasets and the fixed validation set. Note that epochs and training rounds are equivalent here: we apply Federated Averaging after every round rather than running multiple local epochs on each client dataset between rounds.
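Because our accuracy counts every target position, including out-of-vocabulary and end-of-sentence tokens, it differs from a masked variant that scores only real vocabulary words. A minimal sketch of that distinction (the token IDs and the OOV/EOS convention here are illustrative):

```python
import numpy as np

OOV, EOS = 1, 2  # illustrative special-token IDs

def accuracy(y_true, y_pred, mask_special=False):
    """Fraction of correctly predicted next-word tokens.
    With mask_special=True, OOV/EOS positions are excluded;
    our reported metric keeps them (mask_special=False)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    keep = np.ones_like(y_true, dtype=bool)
    if mask_special:
        keep = ~np.isin(y_true, [OOV, EOS])
    return (y_true[keep] == y_pred[keep]).mean()

y_true = [5, 7, OOV, EOS, 9]
y_pred = [5, 3, OOV, EOS, 9]
acc_all = accuracy(y_true, y_pred)           # all 5 positions count -> 0.8
acc_masked = accuracy(y_true, y_pred, True)  # only the 3 real words -> ~0.667
```

Including the special tokens makes the metric easier to hit on average (models learn to predict EOS and OOV well), which is worth keeping in mind when comparing our numbers against masked-accuracy results.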
We see that even with a relatively small embedding dimension (100 units), the model with pretrained word embeddings converges faster during federated training. We find this result promising and aim to expand on this analysis through further experimentation. In particular, we aim to measure how pretraining an entire network centrally and fine-tuning it in the federated context compares with both federated training from scratch and federated training with pretrained word embeddings. We also aim to explore other model configurations for handling text sequences.
Here we include the original schedule from our project proposal and a modified version of that schedule given our work to date.
January:
February:
March:
January:
February:
March:
In the data section we highlight a couple of challenges. Additionally, we have found that our experiments quickly drain cloud credits, and we are hopeful that we will be able to secure more credits in order to execute the items in our revised schedule. We have also found reconciling standard TensorFlow modeling approaches with the TensorFlow Federated API to be somewhat challenging, as we would like to test a variety of approaches to training language models in the federated context. That said, we view our analysis and code toward this effort as one of our primary contributions, and we hope that others find our work instructive for building text models with the TensorFlow Federated API.
This project draws mainly from the following sources, but other sources are referenced throughout this repository.